{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trainer control\n",
    "千帆Python SDK 在使用[trainer 实现训练微调](./trainer_finetune_dataset2deploy.ipynb)的基础上，SDK还提供了灵活的事件回调、以及trainer的可恢复的特性，以下以新建训练任务，并注册EventHandler，遇到报错之后进行resume进行演示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! pip install \"qianfan>=0.2.2\" -U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'0.2.2'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import qianfan\n",
    "qianfan.__version__"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 前置准备\n",
    "- 初始化千帆安全认证AK、SK"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "\n",
    "os.environ[\"QIANFAN_ACCESS_KEY\"] = \"your_ak\"\n",
    "os.environ[\"QIANFAN_SECRET_KEY\"] = \"your_sk\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 导入依赖\n",
    "- `qianfan.trainer.consts` trainer使用中所用到的常量\n",
    "- `qianfan.resources.console.consts` api层面定义的字段常量\n",
    "- `qianfan.trainer.configs` trainer使用所需要的config配置数据类\n",
    "- `qianfan.trainer.LLMFinetune` 大语言模型fine-tune任务Trainer实现\n",
    "- `qianfan.trainer.Service` service类，用于表示平台的模型服务，可以通过trainer.result获取\n",
    "- `qianfan.dataset.Dataset` 千帆dataset类，用于管理千帆平台、本地、第三方数据集的导入导出，数据清洗等操作"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qianfan.trainer.consts import ActionState\n",
    "from qianfan.model.consts import ServiceType\n",
    "from qianfan.resources.console import consts as console_consts\n",
    "from qianfan.trainer.configs import TrainConfig\n",
    "from qianfan.model.configs import DeployConfig\n",
    "from qianfan.resources import QfMessages\n",
    "from qianfan.trainer import LLMFinetune\n",
    "from qianfan.model import Service\n",
    "from qianfan.dataset import Dataset\n",
    "from typing import cast\n",
    "from qianfan.utils import enable_log\n",
    "import logging\n",
    "\n",
    "enable_log(logging.INFO)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EventHandler\n",
    "\n",
    "如果需要在训练过程中监控每个阶段的各个节点的状态，可以通过事件回调函数来实现，通过事件的对应的action_state可以获取当前的action的运行情况以实现对应的业务回调，插入自定义逻辑"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qianfan.model import Model\n",
    "from qianfan.dataset import Dataset\n",
    "from qianfan.trainer.consts import PeftType\n",
    "\n",
    "# 首先需要先加载测试数据集，这里以加载平台预置数据集为例子：\n",
    "ds = Dataset.load(qianfan_dataset_id=\"ds-xxx\")\n",
    "trainer = LLMFinetune(\n",
    "    train_type=\"ERNIE-Speed-8K\",\n",
    "    train_config=TrainConfig(\n",
    "        epoch=1,\n",
    "        learning_rate=0.0003,\n",
    "        max_seq_len=4096,\n",
    "        peft_type=PeftType.LoRA,\n",
    "        logging_steps=1,\n",
    "        warmup_ratio=0.10,\n",
    "        weight_decay=0.0100,\n",
    "        lora_rank=8,\n",
    "        lora_all_linear=\"True\",\n",
    "    ),\n",
    "    dataset=ds,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qianfan.trainer.event import Event, EventHandler\n",
    "\n",
    "testset: Dataset = Dataset.load(data_file=\"./data/fin_cqa_test.jsonl\")\n",
    "# 定义自己的EventHandler，并实现dispatch方法\n",
    "class InferAfterSFT(EventHandler):\n",
    "    target_action: str\n",
    "    def __init__(self, target_action: str) -> None:\n",
    "        super().__init__()\n",
    "        self.target_action = target_action\n",
    "\n",
    "    def dispatch(self, event: Event) -> None:\n",
    "        print(\"receive: <\", event)\n",
    "        if self.target_action == event.action_id and event.action_state == ActionState.Done:\n",
    "            svc = cast(Service, event.data[\"service\"])\n",
    "            print(\"svc\", svc)\n",
    "            for row in testset.list():\n",
    "                msgs = QfMessages()\n",
    "                msgs.append(row[0][0][\"prompt\"], \"user\")\n",
    "                svc.exec({\"messages\":\"msgs\"})\n",
    "                print(\"row infer result\", row)\n",
    "            \n",
    "\n",
    "eh = InferAfterSFT(target_action=trainer.ppls[0].id)\n",
    "trainer.register_event_handler(eh)\n",
    "trainer.run()"
   ]
  },
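  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A handler does not have to run inference. As a minimal sketch, the hypothetical `StateRecorder` below (an illustrative name, not part of the SDK) only records each state transition it receives. It relies solely on the `action_id` and `action_state` attributes already used above, and it would need to be registered before `trainer.run()` is called."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, List, Tuple\n",
    "\n",
    "from qianfan.trainer.event import Event, EventHandler\n",
    "\n",
    "\n",
    "class StateRecorder(EventHandler):\n",
    "    \"\"\"Hypothetical handler that keeps a history of (action_id, action_state) pairs.\"\"\"\n",
    "\n",
    "    def __init__(self) -> None:\n",
    "        super().__init__()\n",
    "        self.history: List[Tuple[Any, Any]] = []\n",
    "\n",
    "    def dispatch(self, event: Event) -> None:\n",
    "        # Store the transition so it can be inspected after the run\n",
    "        self.history.append((event.action_id, event.action_state))\n",
    "        print(f\"action {event.action_id} -> {event.action_state}\")\n",
    "\n",
    "\n",
    "recorder = StateRecorder()\n",
    "# Register it the same way as the handler above, before starting the run:\n",
    "# trainer.register_event_handler(recorder)"
   ]
  },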
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 任务恢复\n",
    "\n",
    "针对网络中断，服务不稳定等重试无法覆盖的场景，SDK提供了`resume()`以恢复训练过程，这里以LLMFinetune中断后恢复为例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO] [12-07 21:54:28] data_source.py:1044 [t:139789057857344]: data releasing, keep rolling\n",
      "[INFO] [12-07 21:54:30] data_source.py:1044 [t:139789057857344]: data releasing, keep rolling\n",
      "[INFO] [12-07 21:54:33] data_source.py:1044 [t:139789057857344]: data releasing, keep rolling\n",
      "[INFO] [12-07 21:54:38] data_source.py:1044 [t:139789057857344]: data releasing, keep rolling\n",
      "[INFO] [12-07 21:54:41] data_source.py:1053 [t:139789057857344]: data releasing succeeded\n",
      "[INFO] [12-07 21:54:44] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 21:55:14] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 21:55:46] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[ERROR] [12-07 21:56:17] console_requestor.py:42 [t:139789057857344]: console api request failed with error code: 500002, err msg: auth failed, no access, please check the api doc\n"
     ]
    },
    {
     "ename": "APIError",
     "evalue": "api return error, code: 500002, msg: auth failed, no access",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n",
      "\u001b[0;31mAPIError\u001b[0m                                  Traceback (most recent call last)\n",
      "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n",
      "\u001b[0;32m----> 1\u001b[0m trainer\u001b[39m.\u001b[39mrun()\n",
      "\n",
      "File \u001b[0;32m/data/anaconda3/lib/python3.11/site-packages/qianfan/trainer/finetune.py:173\u001b[0m, in \u001b[0;36mLLMFinetune.run\u001b[0;34m(self, **kwargs)\u001b[0m\n",
      "\u001b[1;32m    167\u001b[0m kwargs[\u001b[39m\"\u001b[39m\u001b[39mbackoff_factor\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m kwargs\u001b[39m.\u001b[39mget(\n",
      "\u001b[1;32m    168\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39mbackoff_factor\u001b[39m\u001b[39m\"\u001b[39m, get_config()\u001b[39m.\u001b[39mTRAINER_STATUS_POLLING_BACKOFF_FACTOR\n",
      "\u001b[1;32m    169\u001b[0m )\n",
      "\u001b[1;32m    170\u001b[0m kwargs[\u001b[39m\"\u001b[39m\u001b[39mretry_count\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m kwargs\u001b[39m.\u001b[39mget(\n",
      "\u001b[1;32m    171\u001b[0m     \u001b[39m\"\u001b[39m\u001b[39mretry_count\u001b[39m\u001b[39m\"\u001b[39m, get_config()\u001b[39m.\u001b[39mTRAINER_STATUS_POLLING_RETRY_TIMES\n",
      "\u001b[1;32m    172\u001b[0m )\n",
      "\u001b[0;32m--> 173\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mresult[\u001b[39m0\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mppls[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mexec(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n",
      "\u001b[1;32m    174\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\n",
      "\n",
      "File \u001b[0;32m/data/anaconda3/lib/python3.11/site-packages/qianfan/resources/requestor/console_requestor.py:46\u001b[0m, in \u001b[0;36mConsoleAPIRequestor._check_error\u001b[0;34m(self, body)\u001b[0m\n",
      "\u001b[1;32m     41\u001b[0m err_msg \u001b[39m=\u001b[39m body\u001b[39m.\u001b[39mget(\u001b[39m\"\u001b[39m\u001b[39merror_msg\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mno error message found in response body\u001b[39m\u001b[39m\"\u001b[39m)\n",
      "\u001b[1;32m     42\u001b[0m log_error(\n",
      "\u001b[1;32m     43\u001b[0m     \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mconsole api request failed with error code: \u001b[39m\u001b[39m{\u001b[39;00merror_code\u001b[39m}\u001b[39;00m\u001b[39m, err msg:\u001b[39m\u001b[39m\"\u001b[39m\n",
      "\u001b[1;32m     44\u001b[0m     \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00merr_msg\u001b[39m}\u001b[39;00m\u001b[39m, please check the api doc\u001b[39m\u001b[39m\"\u001b[39m\n",
      "\u001b[1;32m     45\u001b[0m )\n",
      "\u001b[0;32m---> 46\u001b[0m \u001b[39mraise\u001b[39;00m errors\u001b[39m.\u001b[39mAPIError(error_code, err_msg)\n",
      "\n",
      "\u001b[0;31mAPIError\u001b[0m: api return error, code: 500002, msg: auth failed, no access"
     ]
    }
   ],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO] [12-07 22:00:58] actions.py:390 [t:139789057857344]: [train_action] resume from created job 17304/9077\n",
      "[INFO] [12-07 22:00:58] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:01:29] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:02:00] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:02:30] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:03:01] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:03:31] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:04:02] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:17:51] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:18:22] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:18:52] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:19:23] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:19:54] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:32:12] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:32:43] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:33:14] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:33:44] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:34:15] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:34:46] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:35:17] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:35:48] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: RUNNING, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:36:18] actions.py:352 [t:139789057857344]: [train_action] fine-tune running... current status: FINISH, check vdl report in https://console.bce.baidu.com/qianfan/visualdl/index?displayToken=eyJydW5JZCI6InJ1bi10MnlzaWQ3NjE1Z3N0Zm11In0=\n",
      "[INFO] [12-07 22:36:18] actions.py:370 [t:139789057857344]: [train_action] fine-tune job has ended: 9077 with status: FINISH\n",
      "[INFO] [12-07 22:36:19] model.py:183 [t:139789057857344]: check train job: 17304/9077 status before publishing model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INFO] [12-07 22:36:20] model.py:199 [t:139789057857344]: model publishing keep polling, current status FINISH\n",
      "[INFO] [12-07 22:36:20] model.py:233 [t:139789057857344]: model ready to publish\n",
      "[INFO] [12-07 22:36:21] model.py:239 [t:139789057857344]: check model publish status: Creating\n",
      "[INFO] [12-07 22:36:51] model.py:239 [t:139789057857344]: check model publish status: Ready\n",
      "[INFO] [12-07 22:36:51] model.py:241 [t:139789057857344]: model 10248/12701 published successfully\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<qianfan.trainer.finetune.LLMFinetune at 0x7f22c4dd0210>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.resume()"
   ]
  }
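  ,
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The manual run-then-resume sequence above could also be wrapped in a small helper. The sketch below is an assumption-laden illustration, not an SDK feature: `run_with_resume`, `max_resumes`, and `wait_seconds` are hypothetical names, and the import path `qianfan.errors.APIError` is inferred from the traceback above. Resuming only helps once the underlying cause (for example, invalid credentials) has been fixed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Assumed import path for the APIError shown in the traceback above\n",
    "from qianfan.errors import APIError\n",
    "\n",
    "\n",
    "def run_with_resume(trainer, max_resumes=3, wait_seconds=60):\n",
    "    \"\"\"Hypothetical helper: run the trainer, then call resume() after each APIError.\"\"\"\n",
    "    try:\n",
    "        return trainer.run()\n",
    "    except APIError as err:\n",
    "        last_error = err\n",
    "    for attempt in range(1, max_resumes + 1):\n",
    "        print(f\"resume attempt {attempt}/{max_resumes} after error: {last_error}\")\n",
    "        # Give the service or network some time to recover before resuming\n",
    "        time.sleep(wait_seconds)\n",
    "        try:\n",
    "            return trainer.resume()\n",
    "        except APIError as err:\n",
    "            last_error = err\n",
    "    # Every resume attempt failed; surface the most recent error\n",
    "    raise last_error\n",
    "\n",
    "\n",
    "# Usage (starts a real training job, so it is left commented out):\n",
    "# run_with_resume(trainer)"
   ]
  }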
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
